from functools import cached_property
from typing import Any, List, Optional

from truss.cli.logs.base_watcher import LogWatcher
from truss.remote.baseten.api import BasetenApi
from truss.remote.baseten.utils.status import MODEL_RUNNING_STATES

# Earliest point (relative to "now") that the development-deployment log
# cursor is allowed to start from: one hour, expressed in milliseconds.
MAX_LOOK_BACK_MS = 60 * 60 * 1000


class ModelDeploymentLogWatcher(LogWatcher):
    """Log watcher for a single deployment of a Baseten model.

    Logs are fetched through ``BasetenApi.get_model_deployment_logs``. The
    deployment's status is re-read in the ``before_polling`` and ``post_poll``
    hooks, and polling is allowed to continue only while that status is one
    of ``MODEL_RUNNING_STATES``.
    """

    _model_id: str
    _deployment_id: str
    # Last deployment status read from the API; None until a hook refreshes it.
    _current_status: Optional[str] = None

    def __init__(self, api: BasetenApi, model_id: str, deployment_id: str):
        super().__init__(api)
        self._model_id = model_id
        self._deployment_id = deployment_id

    # ------------------------------------------------------------------
    # Polling lifecycle hooks (invoked by the LogWatcher base class).
    # ------------------------------------------------------------------

    def before_polling(self) -> None:
        """Read the deployment status ahead of polling."""
        self._current_status = self._get_current_status()

    def post_poll(self) -> None:
        """Re-read the deployment status; presumably called after each poll."""
        self._current_status = self._get_current_status()

    def after_polling(self) -> None:
        """Nothing to tear down for this watcher."""

    def should_poll_again(self) -> bool:
        """Return True while the last-seen status is a running state."""
        status = self._current_status
        return status in MODEL_RUNNING_STATES

    # ------------------------------------------------------------------
    # Log retrieval.
    # ------------------------------------------------------------------

    def fetch_logs(
        self, start_epoch_millis: Optional[int], end_epoch_millis: Optional[int]
    ) -> List[Any]:
        """Fetch this deployment's logs for the given epoch-ms window.

        Either bound may be None; the values are forwarded unchanged to the
        API client.
        """
        api = self.api
        return api.get_model_deployment_logs(
            self._model_id,
            self._deployment_id,
            start_epoch_millis,
            end_epoch_millis,
        )

    def get_start_epoch_ms(self, now_ms: int) -> Optional[int]:
        """Choose the start of the next log window.

        Development deployments use the last seen log time as a cursor,
        floored at ``now_ms - MAX_LOOK_BACK_MS``; with no cursor yet, None is
        returned. All other deployments defer to the base-class behavior.
        """
        if self._is_development:
            # Cursor logic: resume from the newest log already seen, but never
            # reach back further than the look-back limit.
            last_seen = self._last_log_time_ms
            if not last_seen:
                return None
            earliest_allowed = now_ms - MAX_LOOK_BACK_MS
            return max(last_seen, earliest_allowed)

        # NOTE(Tyron): If there can be multiple replicas (non-development
        # deployments), a timestamp cursor cannot be used to poll for logs.
        return super().get_start_epoch_ms(now_ms)

    # ------------------------------------------------------------------
    # Deployment metadata helpers.
    # ------------------------------------------------------------------

    def _get_deployment(self) -> Any:
        """Fetch the deployment record from the Baseten API."""
        model_id, deployment_id = self._model_id, self._deployment_id
        return self.api.get_deployment(model_id, deployment_id)

    def _get_current_status(self) -> str:
        """Return the ``status`` field of a freshly fetched deployment record."""
        deployment = self._get_deployment()
        return deployment["status"]

    @cached_property
    def _is_development(self) -> bool:
        """Whether the deployment is a development one (fetched once, cached)."""
        deployment = self._get_deployment()
        return deployment["is_development"]
